"""Tool for creating multi-resolution TFRecords datasets for StyleGAN and ProGAN."""

import argparse
import glob
import sys
import threading
import traceback

import PIL.Image
import numpy as np
# pylint: disable=too-many-lines
import os
import six.moves.queue as Queue  # pylint: disable=import-error
import tensorflow as tf

import dnnlib.tflib as tflib
from training import dataset
#from scipy.misc import imresize


# ----------------------------------------------------------------------------


def error(msg):
    """Print an error message to stderr and terminate with exit status 1.

    Args:
        msg: Human-readable description of the fatal condition.
    """
    # stderr + sys.exit: `exit()` is injected by the site module and is not
    # guaranteed to exist in all environments; error text belongs on stderr.
    print("Error: " + msg, file=sys.stderr)
    sys.exit(1)


# ----------------------------------------------------------------------------


class TFRecordExporter:
    """Incrementally writes images (and optional labels) into a TFRecords dataset.

    Records are written to "<tfrecord_dir>/<basename>-r%02d.tfrecords" where
    %02d is ``res_log2``; labels go to a companion "<basename>-rxx.labels"
    numpy file.  Unlike the original ProGAN tool, only a single writer at the
    native resolution is opened -- the legacy per-LOD writers are kept as
    commented-out code below.
    """

    def __init__(
            self, tfrecord_dir, expected_images, res_log2=7, print_progress=True, progress_interval=10
    ):
        """Create the output directory and initialize bookkeeping.

        Args:
            tfrecord_dir: Directory to create; its basename also becomes the
                filename prefix of every record written.
            expected_images: Total number of images that will be added (used
                for the shuffle order and progress reporting).
            res_log2: Image width and height must be multiples of 2**res_log2.
            print_progress: Whether to print progress to stdout.
            progress_interval: Print progress every this many images.
        """
        self.tfrecord_dir = tfrecord_dir
        # Shared prefix "<dir>/<basename-of-dir>" for all output files.
        self.tfr_prefix = os.path.join(
            self.tfrecord_dir, os.path.basename(self.tfrecord_dir)
        )
        self.expected_images = expected_images
        self.cur_images = 0  # number of images added so far
        self.shape = None  # dataset shape in CHW order, fixed on first image
        # self.resolution_log2 = None
        self.res_log2 = res_log2
        self.tfr_writers = []
        self.print_progress = print_progress
        self.progress_interval = progress_interval

        if self.print_progress:
            print('Creating dataset "%s"' % tfrecord_dir)
        if not os.path.isdir(self.tfrecord_dir):
            os.makedirs(self.tfrecord_dir)
        assert os.path.isdir(self.tfrecord_dir)

    def close(self):
        """Flush and close every open writer, then report the final count."""
        if self.print_progress:
            print("%-40s\r" % "Flushing data...", end="", flush=True)
        for tfr_writer in self.tfr_writers:
            tfr_writer.close()
        self.tfr_writers = []
        if self.print_progress:
            print("%-40s\r" % "", end="", flush=True)
            print("Added %d images." % self.cur_images)

    def choose_shuffled_order(
            self,
    ):  # Note: Images and labels must be added in shuffled order.
        """Return a deterministic permutation (seed 123) of all image indices."""
        order = np.arange(self.expected_images)
        np.random.RandomState(123).shuffle(order)
        return order

    def add_image_raw(self, encoded_jpg):
        """Write one pre-encoded image (raw JPEG/PNG bytes) to every writer.

        Stores ``self.shape`` (set beforehand by create_tfr_writer) under the
        "shape" key and the encoded bytes under "img".  NOTE: these feature
        keys differ from add_image, which stores decoded pixels under "data";
        readers must match the variant used to build the dataset.
        """
        if self.print_progress and self.cur_images % self.progress_interval == 0:
            print(
                "%d / %d\r" % (self.cur_images, self.expected_images),
                end="",
                flush=True,
            )
        for lod, tfr_writer in enumerate(self.tfr_writers):
            ex = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        "shape": tf.train.Feature(
                            int64_list=tf.train.Int64List(value=self.shape)
                        ),
                        "img": tf.train.Feature(bytes_list=tf.train.BytesList(value=[encoded_jpg]))
                    }
                )
            )
            tfr_writer.write(ex.SerializeToString())
        self.cur_images += 1

    def create_tfr_writer(self, shape):
        """Open the single output writer for images of HWC ``shape``.

        Converts HWC to CHW for ``self.shape`` and validates the channel
        count (1 or 3) and divisibility of width/height by 2**res_log2.
        Used by the "raw" ingestion path before add_image_raw is called.
        """
        self.shape = [shape[2], shape[0], shape[1]]
        assert self.shape[0] in [1, 3]
        assert self.shape[1] % (2 ** self.res_log2) == 0
        assert self.shape[2] % (2 ** self.res_log2) == 0
        tfr_opt = tf.python_io.TFRecordOptions(
            tf.python_io.TFRecordCompressionType.NONE
        )
        tfr_file = self.tfr_prefix + "-r%02d.tfrecords" % (
                    self.res_log2
        )
        self.tfr_writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))


    def add_image(self, img):
        """Quantize a CHW image to uint8 and write it to every open writer.

        The first call locks in the dataset shape and opens the writer.  If
        multiple writers existed (legacy multi-LOD path, currently disabled),
        each subsequent writer would receive a 2x box-filter downscale.
        """
        if self.print_progress and self.cur_images % self.progress_interval == 0:
            print(
                "%d / %d\r" % (self.cur_images, self.expected_images),
                end="",
                flush=True,
            )
        if self.shape is None:
            self.shape = img.shape
            #self.resolution_log2 = int(np.log2(self.shape[1]))
            assert self.shape[0] in [1, 3]
            #assert self.shape[1] == self.shape[2]
            #assert self.shape[1] == 2 ** self.resolution_log2
            assert self.shape[1] % (2 ** self.res_log2) == 0
            assert self.shape[2] % (2 ** self.res_log2) == 0
            tfr_opt = tf.python_io.TFRecordOptions(
                tf.python_io.TFRecordCompressionType.NONE
            )
            tfr_file = self.tfr_prefix + "-r%02d.tfrecords" % (
                        self.res_log2
                )
            self.tfr_writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))
            """
            for lod in range(self.resolution_log2 - 1):
                tfr_file = self.tfr_prefix + "-r%02d.tfrecords" % (
                        self.resolution_log2 - lod
                )
                self.tfr_writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))
            """
        assert img.shape == self.shape
        for lod, tfr_writer in enumerate(self.tfr_writers):
            if lod:
                # 2x2 box-filter downscale for lower LODs (dead with one writer).
                img = img.astype(np.float32)
                img = (
                              img[:, 0::2, 0::2]
                              + img[:, 0::2, 1::2]
                              + img[:, 1::2, 0::2]
                              + img[:, 1::2, 1::2]
                      ) * 0.25
            # tostring() is the (deprecated) numpy alias of tobytes().
            quant = np.rint(img).clip(0, 255).astype(np.uint8)
            ex = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        "shape": tf.train.Feature(
                            int64_list=tf.train.Int64List(value=quant.shape)
                        ),
                        "data": tf.train.Feature(
                            bytes_list=tf.train.BytesList(value=[quant.tostring()])
                        ),
                    }
                )
            )
            tfr_writer.write(ex.SerializeToString())
        self.cur_images += 1

    def add_labels(self, labels):
        """Save all labels (one row per added image) as float32 to '-rxx.labels'."""
        if self.print_progress:
            print("%-40s\r" % "Saving labels...", end="", flush=True)
        assert labels.shape[0] == self.cur_images
        with open(self.tfr_prefix + "-rxx.labels", "wb") as f:
            np.save(f, labels.astype(np.float32))

    def __enter__(self):
        """Support `with TFRecordExporter(...) as tfr:` usage."""
        return self

    def __exit__(self, *args):
        """Close all writers on scope exit, even after an exception."""
        self.close()


# ----------------------------------------------------------------------------


class ExceptionInfo(object):
    """Snapshot of the exception currently being handled.

    Captures the exception object plus its formatted traceback so worker
    threads can hand failures back to the consuming thread for re-raising.
    Must be instantiated inside an ``except`` block.
    """

    def __init__(self):
        _, exc_value, _ = sys.exc_info()
        self.value = exc_value
        self.traceback = traceback.format_exc()


# ----------------------------------------------------------------------------


class WorkerThread(threading.Thread):
    """Worker that executes ``(func, args, result_queue)`` tasks from a queue.

    A task whose ``func`` is None acts as a poison pill and stops the thread.
    Any exception raised by ``func`` is captured as an ExceptionInfo and
    forwarded through the result queue instead of killing the thread.
    """

    def __init__(self, task_queue):
        super(WorkerThread, self).__init__()
        self.task_queue = task_queue

    def run(self):
        while True:
            func, args, result_queue = self.task_queue.get()
            if func is None:  # poison pill => shut down
                return
            try:
                outcome = func(*args)
            except:  # noqa: E722 -- deliberately forward ANY failure
                outcome = ExceptionInfo()
            result_queue.put((outcome, args))


# ----------------------------------------------------------------------------


class ThreadPool(object):
    """Fixed-size pool of WorkerThreads fed through one shared task queue.

    Results are collected per task-function in separate queues; exceptions
    raised in workers travel back as ExceptionInfo objects and are re-raised
    in the caller's thread by get_result().
    """

    def __init__(self, num_threads):
        """Start ``num_threads`` daemon workers listening on a shared queue."""
        assert num_threads >= 1
        self.task_queue = Queue.Queue()
        # One result queue per distinct task function (created lazily).
        self.result_queues = dict()
        self.num_threads = num_threads
        for _idx in range(self.num_threads):
            thread = WorkerThread(self.task_queue)
            thread.daemon = True  # do not block interpreter exit
            thread.start()

    def add_task(self, func, args=()):
        """Enqueue one call of ``func(*args)`` for execution by a worker."""
        assert hasattr(func, "__call__")  # must be a function
        if func not in self.result_queues:
            self.result_queues[func] = Queue.Queue()
        self.task_queue.put((func, args, self.result_queues[func]))

    def get_result(self, func):  # returns (result, args)
        """Block for one result of ``func``; re-raise worker exceptions here."""
        result, args = self.result_queues[func].get()
        if isinstance(result, ExceptionInfo):
            print("\n\nWorker thread caught an exception:\n" + result.traceback)
            raise result.value
        return result, args

    def finish(self):
        """Send one poison pill per worker so all threads shut down."""
        for _idx in range(self.num_threads):
            self.task_queue.put((None, (), None))

    def __enter__(self):  # for 'with' statement
        return self

    def __exit__(self, *excinfo):
        self.finish()

    def process_items_concurrently(
            self,
            item_iterator,
            process_func=lambda x: x,
            pre_func=lambda x: x,
            post_func=lambda x: x,
            max_items_in_flight=None,
    ):
        """Pipelined map: pre_func -> process_func (in workers) -> post_func.

        Yields post-processed results in the original item order while
        keeping at most ``max_items_in_flight`` items outstanding.
        NOTE(review): retirement uses ``is not None`` as the "done" sentinel,
        so a process_func that legitimately returns None would stall the
        pipeline -- appears to assume non-None results.
        """
        if max_items_in_flight is None:
            max_items_in_flight = self.num_threads * 4
        assert max_items_in_flight >= 1
        results = []
        retire_idx = [0]  # list so the nested closure can mutate it

        def task_func(in_prepared, _idx):
            return process_func(in_prepared)

        def retire_result():
            # Collect one finished task, then yield every result that is now
            # contiguous with the already-retired prefix (preserves order).
            processed, (_prepared, in_idx) = self.get_result(task_func)
            results[in_idx] = processed
            while retire_idx[0] < len(results) and results[retire_idx[0]] is not None:
                yield post_func(results[retire_idx[0]])
                results[retire_idx[0]] = None
                retire_idx[0] += 1

        for idx, item in enumerate(item_iterator):
            prepared = pre_func(item)
            results.append(None)
            self.add_task(func=task_func, args=(prepared, idx))
            # Throttle: drain results while too many items are in flight.
            while retire_idx[0] < idx - max_items_in_flight + 2:
                for res in retire_result():
                    yield res
        # Drain everything that is still outstanding.
        while retire_idx[0] < len(results):
            for res in retire_result():
                yield res


# ----------------------------------------------------------------------------


def display(tfrecord_dir):
    """Interactively preview a TFRecords dataset in an OpenCV window.

    Advances one image per keypress; ESC exits early.
    """
    print('Loading dataset "%s"' % tfrecord_dir)
    tflib.init_tf({"gpu_options.allow_growth": True})
    dset = dataset.TFRecordDataset(
        tfrecord_dir, max_label_size="full", repeat=False, shuffle_mb=0
    )
    tflib.init_uninitialized_vars()
    import cv2  # pip install opencv-python

    shown = 0
    while True:
        try:
            images, labels = dset.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            break
        if shown == 0:
            print("Displaying images")
            cv2.namedWindow("dataset_tool")
            print("Press SPACE or ENTER to advance, ESC to exit")
        print("\nidx = %-8d\nlabel = %s" % (shown, labels[0].tolist()))
        bgr = images[0].transpose(1, 2, 0)[:, :, ::-1]  # CHW => HWC, RGB => BGR
        cv2.imshow("dataset_tool", bgr)
        shown += 1
        if cv2.waitKey() == 27:  # ESC
            break
    print("\nDisplayed %d images." % shown)


# ----------------------------------------------------------------------------


def extract(tfrecord_dir, output_dir):
    """Dump every image of a TFRecords dataset into ``output_dir`` as PNGs."""
    print('Loading dataset "%s"' % tfrecord_dir)
    tflib.init_tf({"gpu_options.allow_growth": True})
    dset = dataset.TFRecordDataset(
        tfrecord_dir, max_label_size=0, repeat=False, shuffle_mb=0
    )
    tflib.init_uninitialized_vars()

    print('Extracting images to "%s"' % output_dir)
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    count = 0
    while True:
        if count % 10 == 0:
            print("%d\r" % count, end="", flush=True)
        try:
            images, _labels = dset.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            break
        chw = images[0]
        if images.shape[1] == 1:
            pil_img = PIL.Image.fromarray(chw[0], "L")  # single-channel
        else:
            pil_img = PIL.Image.fromarray(chw.transpose(1, 2, 0), "RGB")  # CHW => HWC
        pil_img.save(os.path.join(output_dir, "img%08d.png" % count))
        count += 1
    print("Extracted %d images." % count)


# ----------------------------------------------------------------------------


def compare(tfrecord_dir_a, tfrecord_dir_b, ignore_labels):
    """Compare two TFRecords datasets element-by-element and report mismatches."""
    max_label_size = 0 if ignore_labels else "full"
    print('Loading dataset "%s"' % tfrecord_dir_a)
    tflib.init_tf({"gpu_options.allow_growth": True})
    dset_a = dataset.TFRecordDataset(
        tfrecord_dir_a, max_label_size=max_label_size, repeat=False, shuffle_mb=0
    )
    print('Loading dataset "%s"' % tfrecord_dir_b)
    dset_b = dataset.TFRecordDataset(
        tfrecord_dir_b, max_label_size=max_label_size, repeat=False, shuffle_mb=0
    )
    tflib.init_uninitialized_vars()

    def next_pair(dset):
        # Returns (images, labels), or (None, None) once the dataset runs out.
        try:
            return dset.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            return None, None

    print("Comparing datasets")
    idx = 0
    identical_images = 0
    identical_labels = 0
    while True:
        if idx % 100 == 0:
            print("%d\r" % idx, end="", flush=True)
        images_a, labels_a = next_pair(dset_a)
        images_b, labels_b = next_pair(dset_b)
        if images_a is None or images_b is None:
            if images_a is not None or images_b is not None:
                print("Datasets contain different number of images")
            break
        if images_a.shape == images_b.shape and np.all(images_a == images_b):
            identical_images += 1
        else:
            print("Image %d is different" % idx)
        if labels_a.shape == labels_b.shape and np.all(labels_a == labels_b):
            identical_labels += 1
        else:
            print("Label %d is different" % idx)
        idx += 1
    print("Identical images: %d / %d" % (identical_images, idx))
    if not ignore_labels:
        print("Identical labels: %d / %d" % (identical_labels, idx))


# ----------------------------------------------------------------------------


def create_mnist(tfrecord_dir, mnist_dir):
    """Build a TFRecords dataset (with one-hot labels) from the raw MNIST files."""
    print('Loading MNIST from "%s"' % mnist_dir)
    import gzip

    with gzip.open(os.path.join(mnist_dir, "train-images-idx3-ubyte.gz"), "rb") as f:
        images = np.frombuffer(f.read(), np.uint8, offset=16)
    with gzip.open(os.path.join(mnist_dir, "train-labels-idx1-ubyte.gz"), "rb") as f:
        labels = np.frombuffer(f.read(), np.uint8, offset=8)
    # 28x28 digits, zero-padded to 32x32 so the resolution is a power of two.
    images = images.reshape(-1, 1, 28, 28)
    images = np.pad(
        images, [(0, 0), (0, 0), (2, 2), (2, 2)], "constant", constant_values=0
    )
    assert images.shape == (60000, 1, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (60000,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9
    onehot = np.eye(np.max(labels) + 1, dtype=np.float32)[labels]

    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for image_idx in order:
            tfr.add_image(images[image_idx])
        tfr.add_labels(onehot[order])


# ----------------------------------------------------------------------------


def create_mnistrgb(tfrecord_dir, mnist_dir, num_images=1000000, random_seed=123):
    """Synthesize an RGB dataset where each image stacks three random MNIST digits.

    Each output sample uses three independently drawn 32x32 digits as its
    three color channels.
    """
    print('Loading MNIST from "%s"' % mnist_dir)
    import gzip

    with gzip.open(os.path.join(mnist_dir, "train-images-idx3-ubyte.gz"), "rb") as f:
        images = np.frombuffer(f.read(), np.uint8, offset=16)
    images = images.reshape(-1, 28, 28)
    images = np.pad(images, [(0, 0), (2, 2), (2, 2)], "constant", constant_values=0)
    assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255

    with TFRecordExporter(tfrecord_dir, num_images) as tfr:
        rnd = np.random.RandomState(random_seed)
        for _ in range(num_images):
            picks = rnd.randint(images.shape[0], size=3)
            tfr.add_image(images[picks])


# ----------------------------------------------------------------------------


def create_cifar10(tfrecord_dir, cifar10_dir):
    """Build a TFRecords dataset from the five CIFAR-10 python training batches."""
    print('Loading CIFAR-10 from "%s"' % cifar10_dir)
    import pickle

    image_batches = []
    label_batches = []
    for batch in range(1, 6):
        with open(os.path.join(cifar10_dir, "data_batch_%d" % batch), "rb") as f:
            data = pickle.load(f, encoding="latin1")
        image_batches.append(data["data"].reshape(-1, 3, 32, 32))
        label_batches.append(data["labels"])
    images = np.concatenate(image_batches)
    labels = np.concatenate(label_batches)
    assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (50000,) and labels.dtype == np.int32
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9
    onehot = np.eye(np.max(labels) + 1, dtype=np.float32)[labels]

    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for image_idx in order:
            tfr.add_image(images[image_idx])
        tfr.add_labels(onehot[order])


# ----------------------------------------------------------------------------


def create_cifar100(tfrecord_dir, cifar100_dir):
    """Build a TFRecords dataset from the CIFAR-100 training pickle (fine labels)."""
    print('Loading CIFAR-100 from "%s"' % cifar100_dir)
    import pickle

    with open(os.path.join(cifar100_dir, "train"), "rb") as f:
        data = pickle.load(f, encoding="latin1")
    images = data["data"].reshape(-1, 3, 32, 32)
    labels = np.array(data["fine_labels"])
    assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (50000,) and labels.dtype == np.int32
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 99
    onehot = np.eye(np.max(labels) + 1, dtype=np.float32)[labels]

    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for image_idx in order:
            tfr.add_image(images[image_idx])
        tfr.add_labels(onehot[order])


# ----------------------------------------------------------------------------


def create_svhn(tfrecord_dir, svhn_dir):
    """Build a TFRecords dataset from the three SVHN training pickles."""
    print('Loading SVHN from "%s"' % svhn_dir)
    import pickle

    image_batches = []
    label_batches = []
    for batch in range(1, 4):
        with open(os.path.join(svhn_dir, "train_%d.pkl" % batch), "rb") as f:
            data = pickle.load(f, encoding="latin1")
        image_batches.append(data[0])
        label_batches.append(data[1])
    images = np.concatenate(image_batches)
    labels = np.concatenate(label_batches)
    assert images.shape == (73257, 3, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (73257,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9
    onehot = np.eye(np.max(labels) + 1, dtype=np.float32)[labels]

    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for image_idx in order:
            tfr.add_image(images[image_idx])
        tfr.add_labels(onehot[order])


# ----------------------------------------------------------------------------


def create_lsun(tfrecord_dir, lmdb_dir, resolution=256, max_images=None):
    """Create a TFRecords dataset from a single LSUN LMDB category.

    Each entry is decoded (cv2 first, PIL as a fallback), center-cropped to a
    square, resized to resolution x resolution, converted HWC => CHW and
    added.  Undecodable or otherwise failing entries are skipped with a
    printed warning.

    Args:
        tfrecord_dir: New dataset directory to be created.
        lmdb_dir: Directory containing the LSUN LMDB database.
        resolution: Output width/height in pixels.
        max_images: Stop after this many images (default: all entries).
    """
    print('Loading LSUN dataset from "%s"' % lmdb_dir)
    import lmdb  # pip install lmdb # pylint: disable=import-error
    import cv2  # pip install opencv-python
    import io

    with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn:
        total_images = txn.stat()["entries"]  # pylint: disable=no-value-for-parameter
        if max_images is None:
            max_images = total_images
        with TFRecordExporter(tfrecord_dir, max_images) as tfr:
            for _idx, (_key, value) in enumerate(txn.cursor()):
                try:
                    try:
                        img = cv2.imdecode(np.fromstring(value, dtype=np.uint8), 1)
                        if img is None:
                            raise IOError("cv2.imdecode failed")
                        img = img[:, :, ::-1]  # BGR => RGB
                    except IOError:
                        # Fall back to PIL for formats cv2 cannot decode.
                        img = np.asarray(PIL.Image.open(io.BytesIO(value)))
                    # Center-crop to a square using the shorter side.
                    crop = np.min(img.shape[:2])
                    img = img[
                          (img.shape[0] - crop) // 2: (img.shape[0] + crop) // 2,
                          (img.shape[1] - crop) // 2: (img.shape[1] + crop) // 2,
                          ]
                    img = PIL.Image.fromarray(img, "RGB")
                    img = img.resize((resolution, resolution), PIL.Image.ANTIALIAS)
                    img = np.asarray(img)
                    img = img.transpose([2, 0, 1])  # HWC => CHW
                    tfr.add_image(img)
                except:
                    # Deliberate catch-all: report the bad entry and keep going.
                    print(sys.exc_info()[1])
                if tfr.cur_images == max_images:
                    break


# ----------------------------------------------------------------------------


def create_lsun_wide(tfrecord_dir, lmdb_dir, width=512, height=384, max_images=None):
    """Create a non-square LSUN dataset letterboxed onto a square canvas.

    Images narrower than `width` (or whose proportional height falls below
    `height`) are skipped.  Survivors are center-cropped vertically, resized
    to width x height, and pasted into the middle rows of a width x width
    zero canvas so the stored records stay square.

    Args:
        tfrecord_dir: New dataset directory to be created.
        lmdb_dir: Directory containing the LSUN LMDB database.
        width: Output width; must be a power of two.
        height: Output height; must not exceed width.
        max_images: Stop after this many accepted images (default: all).
    """
    assert width == 2 ** int(np.round(np.log2(width)))
    assert height <= width
    print('Loading LSUN dataset from "%s"' % lmdb_dir)
    import lmdb  # pip install lmdb # pylint: disable=import-error
    import cv2  # pip install opencv-python
    import io

    with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn:
        total_images = txn.stat()["entries"]  # pylint: disable=no-value-for-parameter
        if max_images is None:
            max_images = total_images
        with TFRecordExporter(tfrecord_dir, max_images, print_progress=False) as tfr:
            for idx, (_key, value) in enumerate(txn.cursor()):
                try:
                    try:
                        img = cv2.imdecode(np.fromstring(value, dtype=np.uint8), 1)
                        if img is None:
                            raise IOError("cv2.imdecode failed")
                        img = img[:, :, ::-1]  # BGR => RGB
                    except IOError:
                        # Fall back to PIL for formats cv2 cannot decode.
                        img = np.asarray(PIL.Image.open(io.BytesIO(value)))

                    # Height the image would have at `width` with aspect kept.
                    ch = int(np.round(width * img.shape[0] / img.shape[1]))
                    if img.shape[1] < width or ch < height:
                        continue

                    # Vertical center-crop, then resize to the target size.
                    img = img[(img.shape[0] - ch) // 2: (img.shape[0] + ch) // 2]
                    img = PIL.Image.fromarray(img, "RGB")
                    img = img.resize((width, height), PIL.Image.ANTIALIAS)
                    img = np.asarray(img)
                    img = img.transpose([2, 0, 1])  # HWC => CHW

                    # Letterbox onto a square canvas (zeros above and below).
                    canvas = np.zeros([3, width, width], dtype=np.uint8)
                    canvas[:, (width - height) // 2: (width + height) // 2] = img
                    tfr.add_image(canvas)
                    print(
                        "\r%d / %d => %d " % (idx + 1, total_images, tfr.cur_images),
                        end="",
                    )

                except:
                    # Deliberate catch-all: report the bad entry and keep going.
                    print(sys.exc_info()[1])
                if tfr.cur_images == max_images:
                    break
    print()


# ----------------------------------------------------------------------------


def create_celeba(tfrecord_dir, celeba_dir, cx=89, cy=121):
    """Build a TFRecords dataset from aligned CelebA PNGs, cropped to 128x128.

    The crop is a 128x128 window centered at (cx, cy) of each 218x178 image.
    """
    print('Loading CelebA from "%s"' % celeba_dir)
    glob_pattern = os.path.join(celeba_dir, "img_align_celeba_png", "*.png")
    image_filenames = sorted(glob.glob(glob_pattern))
    expected_images = 202599
    if len(image_filenames) != expected_images:
        error("Expected to find %d images" % expected_images)

    with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
        for file_idx in tfr.choose_shuffled_order():
            hwc = np.asarray(PIL.Image.open(image_filenames[file_idx]))
            assert hwc.shape == (218, 178, 3)
            cropped = hwc[cy - 64: cy + 64, cx - 64: cx + 64]
            tfr.add_image(cropped.transpose(2, 0, 1))  # HWC => CHW


# ----------------------------------------------------------------------------


def _get_all_files(path):
    if os.path.isfile(path):
        return [path]

    possible_files = sorted(glob.glob(os.path.join(path, "*")))
    return_list = []
    for possible_file in possible_files:
        return_list.extend(_get_all_files(possible_file))
    return return_list


def create_from_images(tfrecord_dir, image_dir, shuffle, res_log2=7, resize=None):
    """Create a TFRecords dataset from a directory tree of images.

    Args:
        tfrecord_dir: New dataset directory to be created.
        image_dir: Directory (searched recursively) containing the images.
        shuffle: Randomize image order when truthy.
        res_log2: Image width/height must be multiples of 2**res_log2
            (enforced by TFRecordExporter when writing).
        resize: Optional exponent; when given, each image is resized to
            (2**resize, 2**resize) before being added.
    """
    print('Loading images from "%s"' % image_dir)
    image_filenames = _get_all_files(image_dir)
    print(f"detected {len(image_filenames)} images ...")
    if not image_filenames:
        error("No input images found")

    # Probe the first image to determine the channel count.
    img = np.asarray(PIL.Image.open(image_filenames[0]))
    channels = img.shape[2] if img.ndim == 3 else 1
    if channels not in [1, 3]:
        error("Input images must be stored as RGB or grayscale")

    # Bug fix: forward res_log2 -- previously the CLI --res_log2 value was
    # accepted here but silently ignored (exporter always used its default).
    with TFRecordExporter(tfrecord_dir, len(image_filenames), res_log2=res_log2) as tfr:
        order = (
            tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))
        )
        print("Adding the images to tfrecords ...")
        for idx in range(order.size):
            img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
            if resize is not None:
                size = int(2 ** resize)
                # Bug fix: "Image" was an undefined name (only PIL is
                # imported); use PIL.Image explicitly.
                img = np.array(PIL.Image.fromarray(img).resize((size, size)))
            if channels == 1:
                img = img[np.newaxis, :, :]  # HW => CHW
            else:
                img = img.transpose([2, 0, 1])  # HWC => CHW
            if img.shape[0] > 3:
                img = img[:3, ...]  # drop alpha (or extra) channels
            tfr.add_image(img)

def create_from_images_raw(tfrecord_dir, image_dir, shuffle, res_log2=7, resize=None):
    """Create a TFRecords dataset storing the original encoded image bytes.

    Unlike create_from_images, the files are NOT decoded to pixel arrays;
    their raw bytes are written via TFRecordExporter.add_image_raw.

    Args:
        tfrecord_dir: New dataset directory to be created.
        image_dir: Directory (searched recursively) containing the images.
        shuffle: Randomize image order when truthy.
        res_log2: Image width/height must be multiples of 2**res_log2.
        resize: Accepted for CLI parity but unused -- raw encoded bytes
            cannot be resized here.
    """
    print('Loading images from "%s"' % image_dir)
    image_filenames = _get_all_files(image_dir)
    print(f"detected {len(image_filenames)} images ...")
    if not image_filenames:
        error("No input images found")

    # Probe the first image to determine the channel count and shape.
    img = np.asarray(PIL.Image.open(image_filenames[0]))
    channels = img.shape[2] if img.ndim == 3 else 1
    if channels not in [1, 3]:
        error("Input images must be stored as RGB or grayscale")

    if shuffle:
        print("Shuffle the images...")
    # Bug fix: forward res_log2 -- previously the CLI --res_log2 value was
    # accepted here but silently ignored (exporter always used its default).
    with TFRecordExporter(tfrecord_dir, len(image_filenames), res_log2=res_log2) as tfr:
        order = (
            tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))
        )
        tfr.create_tfr_writer(img.shape)
        print("Adding the images to tfrecords ...")
        for idx in range(order.size):
            if idx % 1000 == 0:
                print("added images", idx)
            with tf.gfile.FastGFile(image_filenames[order[idx]], 'rb') as fid:
                encoded_jpg = fid.read()
                tfr.add_image_raw(encoded_jpg)
# ----------------------------------------------------------------------------


def create_from_hdf5(tfrecord_dir, hdf5_filename, shuffle):
    """Convert a legacy ProGAN HDF5 archive into a TFRecords dataset.

    Picks the highest-resolution "data*" dataset in the archive; if a
    companion "<stem>-labels.npy" file exists, its labels are stored too.
    """
    print('Loading HDF5 archive from "%s"' % hdf5_filename)
    import h5py  # conda install h5py

    with h5py.File(hdf5_filename, "r") as hdf5_file:
        lods = [value for key, value in hdf5_file.items() if key.startswith("data")]
        hdf5_data = max(lods, key=lambda lod: lod.shape[3])
        num_images = hdf5_data.shape[0]
        with TFRecordExporter(tfrecord_dir, num_images) as tfr:
            if shuffle:
                order = tfr.choose_shuffled_order()
            else:
                order = np.arange(num_images)
            for image_idx in order:
                tfr.add_image(hdf5_data[image_idx])
            npy_filename = os.path.splitext(hdf5_filename)[0] + "-labels.npy"
            if os.path.isfile(npy_filename):
                tfr.add_labels(np.load(npy_filename)[order])


# ----------------------------------------------------------------------------


def execute_cmdline(argv):
    """Parse command-line arguments and dispatch to the matching command.

    Each sub-command name maps 1:1 to a module-level function of the same
    name; the parsed arguments are forwarded to it as keyword arguments.

    Args:
        argv: Full argument vector (argv[0] is the program name).
    """
    prog = argv[0]
    parser = argparse.ArgumentParser(
        prog=prog,
        description="Tool for creating multi-resolution TFRecords datasets for StyleGAN and ProGAN.",
        epilog='Type "%s <command> -h" for more information.' % prog,
    )

    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True

    def add_command(cmd, desc, example=None):
        # Helper: register one sub-command with description and usage example.
        epilog = "Example: %s %s" % (prog, example) if example is not None else None
        return subparsers.add_parser(cmd, description=desc, help=desc, epilog=epilog)

    p = add_command("display", "Display images in dataset.", "display datasets/mnist")
    p.add_argument("tfrecord_dir", help="Directory containing dataset")

    p = add_command(
        "extract", "Extract images from dataset.", "extract datasets/mnist mnist-images"
    )
    p.add_argument("tfrecord_dir", help="Directory containing dataset")
    p.add_argument("output_dir", help="Directory to extract the images into")

    p = add_command(
        "compare", "Compare two datasets.", "compare datasets/mydataset datasets/mnist"
    )
    p.add_argument("tfrecord_dir_a", help="Directory containing first dataset")
    p.add_argument("tfrecord_dir_b", help="Directory containing second dataset")
    p.add_argument(
        "--ignore_labels", help="Ignore labels (default: 0)", type=int, default=0
    )

    p = add_command(
        "create_mnist",
        "Create dataset for MNIST.",
        "create_mnist datasets/mnist ~/downloads/mnist",
    )
    p.add_argument("tfrecord_dir", help="New dataset directory to be created")
    p.add_argument("mnist_dir", help="Directory containing MNIST")

    p = add_command(
        "create_mnistrgb",
        "Create dataset for MNIST-RGB.",
        "create_mnistrgb datasets/mnistrgb ~/downloads/mnist",
    )
    p.add_argument("tfrecord_dir", help="New dataset directory to be created")
    p.add_argument("mnist_dir", help="Directory containing MNIST")
    p.add_argument(
        "--num_images",
        help="Number of composite images to create (default: 1000000)",
        type=int,
        default=1000000,
    )
    p.add_argument(
        "--random_seed", help="Random seed (default: 123)", type=int, default=123
    )

    p = add_command(
        "create_cifar10",
        "Create dataset for CIFAR-10.",
        "create_cifar10 datasets/cifar10 ~/downloads/cifar10",
    )
    p.add_argument("tfrecord_dir", help="New dataset directory to be created")
    p.add_argument("cifar10_dir", help="Directory containing CIFAR-10")

    p = add_command(
        "create_cifar100",
        "Create dataset for CIFAR-100.",
        "create_cifar100 datasets/cifar100 ~/downloads/cifar100",
    )
    p.add_argument("tfrecord_dir", help="New dataset directory to be created")
    p.add_argument("cifar100_dir", help="Directory containing CIFAR-100")

    p = add_command(
        "create_svhn",
        "Create dataset for SVHN.",
        "create_svhn datasets/svhn ~/downloads/svhn",
    )
    p.add_argument("tfrecord_dir", help="New dataset directory to be created")
    p.add_argument("svhn_dir", help="Directory containing SVHN")

    p = add_command(
        "create_lsun",
        "Create dataset for single LSUN category.",
        "create_lsun datasets/lsun-car-100k ~/downloads/lsun/car_lmdb --resolution 256 --max_images 100000",
    )
    p.add_argument("tfrecord_dir", help="New dataset directory to be created")
    p.add_argument("lmdb_dir", help="Directory containing LMDB database")
    p.add_argument(
        "--resolution", help="Output resolution (default: 256)", type=int, default=256
    )
    p.add_argument(
        "--max_images",
        help="Maximum number of images (default: none)",
        type=int,
        default=None,
    )

    p = add_command(
        "create_lsun_wide",
        "Create LSUN dataset with non-square aspect ratio.",
        "create_lsun_wide datasets/lsun-car-512x384 ~/downloads/lsun/car_lmdb --width 512 --height 384",
    )
    p.add_argument("tfrecord_dir", help="New dataset directory to be created")
    p.add_argument("lmdb_dir", help="Directory containing LMDB database")
    p.add_argument("--width", help="Output width (default: 512)", type=int, default=512)
    p.add_argument(
        "--height", help="Output height (default: 384)", type=int, default=384
    )
    p.add_argument(
        "--max_images",
        help="Maximum number of images (default: none)",
        type=int,
        default=None,
    )

    p = add_command(
        "create_celeba",
        "Create dataset for CelebA.",
        "create_celeba datasets/celeba ~/downloads/celeba",
    )
    p.add_argument("tfrecord_dir", help="New dataset directory to be created")
    p.add_argument("celeba_dir", help="Directory containing CelebA")
    p.add_argument(
        "--cx", help="Center X coordinate (default: 89)", type=int, default=89
    )
    p.add_argument(
        "--cy", help="Center Y coordinate (default: 121)", type=int, default=121
    )

    p = add_command(
        "create_from_images",
        "Create dataset from a directory full of images. Please be careful"
        "since the tool recursively searches inside every sub-directory for image files",
        "create_from_images datasets/mydataset myimagedir",
    )
    p.add_argument("tfrecord_dir", help="New dataset directory to be created")
    p.add_argument("image_dir", help="Directory containing the images")
    p.add_argument(
        "--resize",
        help="resize to given power of 2 sized square images (default: None)",
        type=int,
        default=None,
        required=False
    )
    p.add_argument(
        "--shuffle", help="Randomize image order (default: 1)", type=int, default=1
    )
    p.add_argument(
        "--res_log2",
        help="image width and height should be multiple of 2**res_log2 (default: 7)",
        type=int,
        default=7
    )

    p = add_command(
        "create_from_images_raw",
        "Create dataset from a directory full of images. Please be careful"
        "since the tool recursively searches inside every sub-directory for image files",
        "create_from_images_raw datasets/mydataset myimagedir",
    )
    p.add_argument("tfrecord_dir", help="New dataset directory to be created")
    p.add_argument("image_dir", help="Directory containing the images")
    p.add_argument(
        "--resize",
        help="resize to given power of 2 sized square images (default: None)",
        type=int,
        default=None,
        required=False
    )
    p.add_argument(
        "--shuffle", help="Randomize image order (default: 1)", type=int, default=1
    )
    p.add_argument(
        "--res_log2",
        help="image width and height should be multiple of 2**res_log2 (default: 7)",
        type=int,
        default=7
    )

    p = add_command(
        "create_from_hdf5",
        "Create dataset from legacy HDF5 archive.",
        "create_from_hdf5 datasets/celebahq ~/downloads/celeba-hq-1024x1024.h5",
    )
    p.add_argument("tfrecord_dir", help="New dataset directory to be created")
    p.add_argument("hdf5_filename", help="HDF5 archive containing the images")
    p.add_argument(
        "--shuffle", help="Randomize image order (default: 1)", type=int, default=1
    )

    # Show help when invoked with no arguments at all.
    args = parser.parse_args(argv[1:] if len(argv) > 1 else ["-h"])
    # Dispatch: the sub-command name doubles as the handler function's name.
    func = globals()[args.command]
    del args.command
    func(**vars(args))


# ----------------------------------------------------------------------------

if __name__ == "__main__":
    execute_cmdline(sys.argv)

# ----------------------------------------------------------------------------
